# ACUPUNCTURE CATE Analysis

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
import warnings
import sys
from datetime import datetime
# NOTE(review): 'ignore' with no category/module filter silences *every*
# warning for the whole process (including numpy RuntimeWarnings such as
# divide-by-zero), which can hide numerical problems; consider narrowing it.
warnings.filterwarnings('ignore')

class TeeOutput:
    """Class to write output to both console and file simultaneously.

    Meant to be installed as ``sys.stdout``: every ``write`` goes to the stream
    that was ``sys.stdout`` when the object was constructed *and* to a log file.
    Create it before replacing ``sys.stdout``, otherwise it would tee into itself.

    Can be used as a context manager; the log file is closed on exit. After
    ``close()``, writes still reach the terminal but are no longer logged, so a
    tee that is accidentally left installed does not crash later prints.

    Args:
        filename: Path of the log file (created/truncated).
        encoding: Text encoding of the log file. Defaults to UTF-8 because the
            analysis prints non-ASCII symbols (ε, γ, ρ, √, →, τ, μ, σ) that the
            platform default encoding (e.g. cp1252 on Windows) cannot encode.
    """

    def __init__(self, filename, encoding='utf-8'):
        self.terminal = sys.stdout
        self.log = open(filename, 'w', encoding=encoding)

    def write(self, message):
        """Write ``message`` to the terminal and (if still open) the log file."""
        self.terminal.write(message)
        if not self.log.closed:
            self.log.write(message)
            self.log.flush()  # Ensure immediate write to file (survives a crash)
        return len(message)

    def flush(self):
        """Flush both underlying streams."""
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; the terminal stream is left untouched."""
        self.log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

class AcupunctureCATEAllocator:
    """Acupuncture CATE allocation algorithm with fixed gamma=0.5 and updated heavy interval threshold.

    Workflow implemented by this class:

    1. ``process_acupuncture_data`` maps raw trial columns onto the canonical
       ``treatment`` / ``outcome`` columns that every other method reads.
    2. A ``create_*_groups`` method partitions patients into subgroups
       (quantile bins of a covariate score, KMeans clusters, or propensity
       strata), keeps groups with enough treated and control patients, and
       attaches a difference-in-means CATE to each.
    3. ``normalize_cates`` min-max scales the CATEs into ``normalized_cate``.
    4. ``analyze_method`` repeatedly simulates noisy estimates of the
       normalized CATEs (``run_single_trial``) and checks, for each budget K,
       whether picking the top-K estimated groups reaches a (1 - ε) fraction
       of the optimal top-K value.

    Args:
        epsilon: Target accuracy ε reported at construction (trial methods
            take their own ``epsilon_val``).
        gamma: Multiplier γ; the coarse estimation accuracy is ρ = γ√ε.
        delta: Failure probability δ used for the Hoeffding sample size.
        heavy_multiplier: The interval [τ_K, τ_K + 2ρ] is "heavy" when it holds
            more than ``heavy_multiplier * 2ρ * n_groups`` groups.
        random_seed: Seed for NumPy's global RNG and for sklearn estimators.
    """

    def __init__(self, epsilon=0.1, gamma=0.5, delta=0.05, heavy_multiplier=1.6, random_seed=42):
        self.epsilon = epsilon
        self.gamma = gamma
        self.rho = gamma * np.sqrt(epsilon)
        self.delta = delta
        self.heavy_multiplier = heavy_multiplier
        self.random_seed = random_seed
        np.random.seed(random_seed)  # NOTE: seeds NumPy's *global* RNG

        print(f"Acupuncture CATE Allocation Algorithm")
        print(f"ε = {epsilon}")
        print(f"√ε = {np.sqrt(epsilon):.6f}")
        print(f"γ = {gamma}")
        print(f"ρ = γ√ε = {self.rho:.6f}")
        print(f"Heavy multiplier = {heavy_multiplier}x")
        print(f"δ = {delta}")
        print("="*60)

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------

    def process_acupuncture_data(self, df, outcome_col='pk5', treatment_col='group'):
        """Process acupuncture dataset for analysis.

        Works on a copy of ``df``: adds ``treatment``, ``outcome`` and
        ``baseline_headache`` (``pk1`` if present, else 0) columns, then drops
        rows whose outcome or treatment is missing.

        Args:
            df: Raw patient-level DataFrame.
            outcome_col: Column holding the outcome.
            treatment_col: Column holding the treatment indicator. Downstream
                CATE code compares it with 1 and 0, so it is assumed binary.

        Returns:
            The processed copy.

        Raises:
            ValueError: If ``treatment_col`` or ``outcome_col`` is missing.
        """
        print(f"Processing acupuncture data with {len(df)} patients")
        print(f"Available columns: {list(df.columns)}")

        df_processed = df.copy()

        if treatment_col not in df_processed.columns:
            raise ValueError(f"Missing required treatment column: {treatment_col}")
        if outcome_col not in df_processed.columns:
            raise ValueError(f"Missing required outcome column: {outcome_col}")

        df_processed['treatment'] = df_processed[treatment_col]
        df_processed['outcome'] = df_processed[outcome_col]
        df_processed['baseline_headache'] = df_processed['pk1'] if 'pk1' in df_processed.columns else 0

        initial_size = len(df_processed)
        df_processed = df_processed.dropna(subset=['outcome', 'treatment'])
        final_size = len(df_processed)

        if initial_size != final_size:
            print(f"Dropped {initial_size - final_size} rows due to missing outcome/treatment")

        print(f"Final dataset: {final_size} patients")
        print(f"Treatment distribution: {df_processed['treatment'].value_counts().to_dict()}")
        print(f"Outcome (12-month headache score) statistics: mean={df_processed['outcome'].mean():.2f}, std={df_processed['outcome'].std():.2f}")
        # 'baseline_headache' is always created above, so its stats are always reported.
        print(f"Baseline headache stats: mean={df_processed['baseline_headache'].mean():.2f}, std={df_processed['baseline_headache'].std():.2f}")

        return df_processed

    # ------------------------------------------------------------------
    # Private grouping helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _filled(df, col):
        """Return ``df[col]`` with missing values replaced by the column median."""
        return df[col].fillna(df[col].median())

    @staticmethod
    def _minmax(values):
        """Min-max scale a Series to [0, 1]; a constant Series maps to zeros."""
        lo, hi = values.min(), values.max()
        if hi > lo:
            return (values - lo) / (hi - lo)
        return values * 0

    @staticmethod
    def _quantile_labels(values, n_groups):
        """Assign each value to one of ``n_groups`` quantile bins, labelled 0..n_groups-1.

        Only the interior cut points are passed to ``np.digitize``. Digitizing
        against all cuts and subtracting 1 gives the maximum value(s) the
        out-of-range label ``n_groups``, so they would silently fall out of
        every group. Tied cut points still produce empty bins.
        """
        values = np.asarray(values, dtype=float)
        cuts = np.percentile(values, np.linspace(0, 100, n_groups + 1))
        return np.digitize(values, cuts[1:-1])

    @staticmethod
    def _encode_covariates(df, feature_cols):
        """Numeric copy of ``df[feature_cols]``.

        Object columns are label-encoded (missing -> 'missing'); numeric
        columns are median-imputed.
        """
        X = df[feature_cols].copy()
        for col in X.columns:
            if X[col].dtype == 'object':
                X[col] = LabelEncoder().fit_transform(X[col].fillna('missing'))
            elif X[col].isna().any():
                X[col] = X[col].fillna(X[col].median())
        return X

    def _build_groups(self, df, labels, n_groups, min_size, id_prefix, group_type, created_label):
        """Group rows by label, keep groups of at least ``min_size`` rows, report, then balance-filter."""
        labels = np.asarray(labels)
        candidates = (
            {'id': f'{id_prefix}{i}', 'indices': df.index[labels == i].tolist(), 'type': group_type}
            for i in range(n_groups)
        )
        groups = [g for g in candidates if len(g['indices']) >= min_size]

        print(f"Created {len(groups)} {created_label}")
        return self._ensure_balance_and_compute_cate(df, groups)

    # ------------------------------------------------------------------
    # Group construction strategies
    # ------------------------------------------------------------------

    def create_age_chronicity_groups(self, df, n_groups=30, min_size=6):
        """Create groups based on age-chronicity interaction.

        The score is (normalized age) * (normalized chronicity), medians
        imputed. It is split into ``n_groups`` quantile bins. Returns [] if
        either column is missing.
        """
        print(f"Creating age-chronicity interaction groups (target: {n_groups})")

        if 'age' not in df.columns or 'chronicity' not in df.columns:
            print("No age or chronicity variables found")
            return []

        interaction_score = self._minmax(self._filled(df, 'age')) * self._minmax(self._filled(df, 'chronicity'))
        labels = self._quantile_labels(interaction_score, n_groups)
        return self._build_groups(df, labels, n_groups, min_size,
                                  'age_chronicity_group_', 'age_chronicity',
                                  'age-chronicity interaction groups')

    def create_age_groups(self, df, n_groups=30, min_size=6):
        """Create groups based on age brackets (quantile bins of median-imputed age)."""
        print(f"Creating age groups (target: {n_groups})")

        if 'age' not in df.columns:
            print("No age variable found")
            return []

        labels = self._quantile_labels(self._filled(df, 'age'), n_groups)
        return self._build_groups(df, labels, n_groups, min_size,
                                  'age_group_', 'age', 'age groups')

    def create_chronicity_groups(self, df, n_groups=30, min_size=6):
        """Create groups based on headache chronicity (quantile bins, median-imputed)."""
        print(f"Creating chronicity groups (target: {n_groups})")

        if 'chronicity' not in df.columns:
            print("No chronicity variable found")
            return []

        labels = self._quantile_labels(self._filled(df, 'chronicity'), n_groups)
        return self._build_groups(df, labels, n_groups, min_size,
                                  'chronicity_group_', 'chronicity', 'chronicity groups')

    def create_multidimensional_groups(self, df, n_groups=30, min_size=6):
        """Create groups based on composite score of all continuous variables.

        The composite is the mean of the min-max normalized (median-imputed)
        ``age``, ``chronicity`` and ``pk1`` columns that are present. At least
        two of them are required.
        """
        print(f"Creating multidimensional composite groups (target: {n_groups})")

        continuous_vars = ['age', 'chronicity', 'pk1']
        available_vars = [col for col in continuous_vars if col in df.columns]

        if len(available_vars) < 2:
            print("Not enough continuous variables for multidimensional grouping")
            return []

        print(f"Using continuous variables: {available_vars}")

        composite_score = sum(self._minmax(self._filled(df, var)) for var in available_vars) / len(available_vars)

        labels = self._quantile_labels(composite_score, n_groups)
        return self._build_groups(df, labels, n_groups, min_size,
                                  'multidim_group_', 'multidimensional', 'multidimensional groups')

    def create_baseline_headache_groups(self, df, n_groups=30, min_size=6):
        """Create groups based on baseline headache scores (``pk1`` quantile bins)."""
        print(f"Creating baseline headache groups (target: {n_groups})")

        if 'pk1' not in df.columns:
            print("No baseline headache data available")
            return []

        labels = self._quantile_labels(self._filled(df, 'pk1'), n_groups)
        return self._build_groups(df, labels, n_groups, min_size,
                                  'baseline_headache_', 'baseline_headache', 'baseline headache groups')

    def create_baseline_age_groups(self, df, n_groups=30, min_size=6):
        """Create groups based on baseline headache-age interaction.

        The score is (normalized pk1) * (normalized age), medians imputed. It is
        split into ``n_groups`` quantile bins.
        """
        print(f"Creating baseline-age interaction groups (target: {n_groups})")

        if 'pk1' not in df.columns or 'age' not in df.columns:
            print("No baseline headache or age variables found")
            return []

        interaction_score = self._minmax(self._filled(df, 'pk1')) * self._minmax(self._filled(df, 'age'))
        labels = self._quantile_labels(interaction_score, n_groups)
        return self._build_groups(df, labels, n_groups, min_size,
                                  'baseline_age_group_', 'baseline_age',
                                  'baseline-age interaction groups')

    def create_covariate_forest_groups(self, df, n_groups=30, min_size=6):
        """Create groups using clustering on baseline covariates only.

        Despite the name, this standardizes the encoded covariates and runs
        KMeans (not a forest). The cluster count is capped at the number of
        rows, because KMeans cannot fit more clusters than samples.
        """
        print(f"Creating covariate-based forest groups (target: {n_groups})")

        feature_cols = ['age', 'sex', 'migraine', 'chronicity', 'pk1']  # baseline covariates
        available_features = [col for col in feature_cols if col in df.columns]

        if not available_features:
            print("No features available for covariate clustering")
            return []

        X = self._encode_covariates(df, available_features)

        n_clusters = min(n_groups, len(X))
        if n_clusters < 1:
            print("No rows available for covariate clustering")
            return []

        cluster_features = StandardScaler().fit_transform(X.values)
        labels = KMeans(n_clusters=n_clusters, random_state=self.random_seed).fit_predict(cluster_features)

        return self._build_groups(df, labels, n_groups, min_size,
                                  'covariate_cluster_', 'covariate_cluster', 'covariate-based groups')

    def create_propensity_groups(self, df, n_groups=50, min_size=6):
        """Create groups based on propensity score strata.

        Propensity scores are out-of-fold (5-fold) logistic-regression
        probabilities of ``treatment``, split into ``n_groups`` quantile strata.
        """
        print(f"Creating propensity score groups (target: {n_groups})")

        feature_cols = ['age', 'sex', 'migraine', 'chronicity', 'pk1']  # baseline covariates
        available_features = [col for col in feature_cols if col in df.columns]

        if not available_features:
            print("No features available for propensity scoring")
            return []

        X = self._encode_covariates(df, available_features)

        try:
            prop_scores = cross_val_predict(
                LogisticRegression(random_state=self.random_seed, max_iter=1000),
                X, df['treatment'], method='predict_proba', cv=5
            )[:, 1]
        except Exception as e:  # deliberate best-effort: report and produce no groups
            print(f"Error computing propensity scores: {e}")
            return []

        labels = self._quantile_labels(prop_scores, n_groups)
        return self._build_groups(df, labels, n_groups, min_size,
                                  'propensity_', 'propensity', 'propensity groups')

    def _ensure_balance_and_compute_cate(self, df, groups):
        """Ensure treatment balance and compute group CATE.

        Keeps groups with a treatment rate in [0.15, 0.85] and at least 3
        treated and 3 control rows. CATE = mean(control outcome) -
        mean(treated outcome), so a lower outcome under treatment is positive.
        Assumes ``treatment`` is coded 1/0.
        """
        balanced_groups = []

        for group in groups:
            group_df = df.loc[group['indices']]

            treatment_rate = group_df['treatment'].mean()
            n_treated = group_df['treatment'].sum()
            n_control = len(group_df) - n_treated

            if not (0.15 <= treatment_rate <= 0.85 and n_treated >= 3 and n_control >= 3):
                continue

            treated_outcomes = group_df[group_df['treatment'] == 1]['outcome']
            control_outcomes = group_df[group_df['treatment'] == 0]['outcome']
            # Reverse sign: lower headache scores = better (treatment benefit)
            cate = -(treated_outcomes.mean() - control_outcomes.mean())

            balanced_groups.append({
                'id': group['id'],
                'indices': group['indices'],
                'size': len(group_df),
                'treatment_rate': treatment_rate,
                'n_treated': int(n_treated),
                'n_control': int(n_control),
                'cate': cate,
                'type': group['type']
            })

        return balanced_groups

    # ------------------------------------------------------------------
    # Normalization / plotting
    # ------------------------------------------------------------------

    def normalize_cates(self, groups):
        """Normalize CATE values to [0,1].

        Adds ``normalized_cate`` to each group dict in place and returns the
        same list. If all CATEs are equal, every group gets 0.5. An empty list
        is returned unchanged.
        """
        if not groups:
            print("CATE normalization: no groups to normalize")
            return groups

        cates = [g['cate'] for g in groups]
        min_cate, max_cate = min(cates), max(cates)

        if max_cate > min_cate:
            for group in groups:
                group['normalized_cate'] = (group['cate'] - min_cate) / (max_cate - min_cate)
        else:
            for group in groups:
                group['normalized_cate'] = 0.5

        print(f"CATE normalization: [{min_cate:.3f}, {max_cate:.3f}] → [0, 1]")
        return groups

    def plot_cate_distribution(self, groups, title_suffix=""):
        """Plot side-by-side histograms of the original and normalized CATEs."""
        original_cates = [g['cate'] for g in groups]
        normalized_cates = [g['normalized_cate'] for g in groups]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

        ax1.hist(original_cates, bins=15, alpha=0.7, color='skyblue', edgecolor='black')
        ax1.set_xlabel('Original CATE (headache reduction effect)')
        ax1.set_ylabel('Frequency')
        ax1.set_title(f'Original CATE Distribution{title_suffix}')
        ax1.grid(True, alpha=0.3)

        ax2.hist(normalized_cates, bins=15, alpha=0.7, color='lightcoral', edgecolor='black')
        ax2.set_xlabel('Normalized CATE (τ)')
        ax2.set_ylabel('Frequency')
        ax2.set_title(f'Normalized CATE Distribution{title_suffix}')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    # ------------------------------------------------------------------
    # Simulation
    # ------------------------------------------------------------------

    def estimate_tau(self, true_tau, accuracy):
        """Estimate tau using Hoeffding's inequality with Bernoulli samples.

        Draws n = ceil(ln(2/δ) / (2·accuracy²)) Bernoulli(``true_tau``) samples
        from NumPy's global RNG. ``true_tau`` must lie in [0, 1].

        Returns:
            (sample mean, n)
        """
        sample_size = int(np.ceil(np.log(2/self.delta) / (2 * accuracy**2)))
        samples = np.random.binomial(1, true_tau, sample_size)
        return np.mean(samples), sample_size

    def _heavy_interval(self, tau_k, taus, rho, n_groups):
        """Count ``taus`` in [tau_k, tau_k + 2ρ] and test it against the heavy threshold.

        Returns:
            (count, is_heavy), where is_heavy = count > heavy_multiplier * 2ρ * n_groups.
        """
        count = np.sum((taus >= tau_k) & (taus <= tau_k + 2 * rho))
        expected = 2 * rho * n_groups
        return count, count > self.heavy_multiplier * expected

    def run_single_trial(self, groups, epsilon_val, trial_seed):
        """Run allocation algorithm for single trial with fixed gamma.

        Reseeds NumPy's global RNG with ``random_seed + trial_seed``. It then
        draws ρ-accurate estimates for every group, followed by ε-accurate
        estimates; the draw order is part of the reproducibility contract. For
        each budget K in 1..n_groups-1 it compares the true value of the top-K
        groups by each estimate set against the true optimal top-K value.

        Returns:
            (list of per-K result dicts, array of ρ-accurate estimates)
        """
        np.random.seed(self.random_seed + trial_seed)

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])
        rho = self.gamma * np.sqrt(epsilon_val)

        tau_estimates_rho = np.array([self.estimate_tau(tau, rho)[0] for tau in tau_true])
        tau_estimates_eps = np.array([self.estimate_tau(tau, epsilon_val)[0] for tau in tau_true])

        # Sort orders do not depend on K: compute them once.
        order_true = np.argsort(tau_true)
        order_rho = np.argsort(tau_estimates_rho)
        order_eps = np.argsort(tau_estimates_eps)
        target_ratio = 1 - epsilon_val

        results = []

        for K in range(1, n_groups):
            optimal_value = np.sum(tau_true[order_true[-K:]])

            rho_indices = order_rho[-K:]
            rho_value = np.sum(tau_true[rho_indices])
            eps_value = np.sum(tau_true[order_eps[-K:]])

            rho_ratio = rho_value / optimal_value if optimal_value > 0 else 0
            eps_ratio = eps_value / optimal_value if optimal_value > 0 else 0

            # argsort is ascending, so rho_indices[0] is the K-th largest estimate.
            tau_k_est = tau_estimates_rho[rho_indices[0]]
            units_in_a2, is_heavy = self._heavy_interval(tau_k_est, tau_estimates_rho, rho, n_groups)

            results.append({
                'K': K,
                'optimal_value': optimal_value,
                'rho_value': rho_value,
                'eps_value': eps_value,
                'rho_ratio': rho_ratio,
                'eps_ratio': eps_ratio,
                'rho_success': rho_ratio >= target_ratio,
                'eps_success': eps_ratio >= target_ratio,
                'is_heavy': is_heavy,
                'tau_k_est': tau_k_est,
                'units_in_a2': units_in_a2
            })

        return results, tau_estimates_rho

    def find_recovery_units(self, K, tau_true, tau_estimates, epsilon_val, max_extra=10):
        """Find minimum units needed to achieve 1-epsilon performance.

        Starting from the top-K groups by ``tau_estimates``, add the next-best
        estimated groups one at a time, up to ``max_extra``. Return the first
        count whose true value reaches (1 - ε) times the optimal top-K value,
        or None if that never happens. A non-positive optimum counts as a
        failure, matching ``run_single_trial``.
        """
        order = np.argsort(tau_estimates)
        rho_indices = order[-K:]
        optimal_value = np.sum(tau_true[np.argsort(tau_true)[-K:]])

        if optimal_value <= 0:
            return None

        # Candidates outside the top-K, best estimate first.
        remaining_indices = order[:-K][::-1]

        for extra in range(1, min(max_extra, len(remaining_indices)) + 1):
            expanded_indices = np.concatenate([rho_indices, remaining_indices[:extra]])
            if np.sum(tau_true[expanded_indices]) / optimal_value >= (1 - epsilon_val):
                return extra

        return None

    def find_closest_working_budget(self, failed_K, trial_results):
        """Find closest budget that works for a failed budget.

        Returns:
            (distance to the nearest successful budget in either direction,
             distance to the nearest smaller successful budget or None).
            Both are None when no budget succeeded.
        """
        working_budgets = [r['K'] for r in trial_results if r['rho_success']]

        if not working_budgets:
            return None, None

        min_distance_any = min(abs(K - failed_K) for K in working_budgets)

        smaller_working = [K for K in working_budgets if K < failed_K]
        min_distance_smaller = failed_K - max(smaller_working) if smaller_working else None

        return min_distance_any, min_distance_smaller

    def analyze_method(self, groups, epsilon_val, n_trials=30):
        """Analyze single method with fixed gamma and updated heavy threshold.

        Runs ``n_trials`` seeded trials. For every budget whose ρ-based
        selection misses the (1 - ε) target, it records:

        * the heavy-interval flag, both from the estimates and from the true
          normalized CATEs;
        * the recovery units (``find_recovery_units``);
        * the distances to working budgets (``find_closest_working_budget``).

        Returns:
            One dict of statistics per trial.
        """
        print(f"\nAnalyzing {len(groups)} groups with ε={epsilon_val}, γ={self.gamma}")

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])
        rho = self.gamma * np.sqrt(epsilon_val)
        order_true = np.argsort(tau_true)  # trial-invariant

        trial_data = []

        for trial in range(n_trials):
            print(f"Trial {trial + 1}/{n_trials}...")

            trial_results, tau_estimates = self.run_single_trial(groups, epsilon_val, trial)

            failed_results = [r for r in trial_results if not r['rho_success']]
            failed_budgets = [r['K'] for r in failed_results]

            failed_heavy_estimated = [r['is_heavy'] for r in failed_results]
            # Same heavy test, anchored at the true K-th largest CATE and counting true CATEs.
            failed_heavy_true = [
                self._heavy_interval(tau_true[order_true[-r['K']]], tau_true, rho, n_groups)[1]
                for r in failed_results
            ]

            print(f"  Failed budgets: {failed_budgets}")

            if failed_budgets:
                estimated_clean = [bool(x) for x in failed_heavy_estimated]
                true_clean = [bool(x) for x in failed_heavy_true]
                print(f"  HEAVY INTERVALS - Estimated: {estimated_clean}")
                print(f"  HEAVY INTERVALS - True τ_K:   {true_clean}")

            total_heavy = sum(r['is_heavy'] for r in trial_results)
            failed_heavy = sum(failed_heavy_estimated)

            recovery_units = []
            distances_to_working_any = []
            distances_to_working_smaller = []

            for failed_result in failed_results:
                K = failed_result['K']

                recovery = self.find_recovery_units(K, tau_true, tau_estimates, epsilon_val)
                if recovery is not None:
                    recovery_units.append(recovery)

                distance_any, distance_smaller = self.find_closest_working_budget(K, trial_results)
                if distance_any is not None:
                    distances_to_working_any.append(distance_any)
                if distance_smaller is not None:
                    distances_to_working_smaller.append(distance_smaller)

            trial_data.append({
                'trial': trial,
                'failed_budgets': failed_budgets,
                'num_failures': len(failed_results),
                'total_heavy': total_heavy,
                'failed_heavy': failed_heavy,
                'failed_heavy_estimated': failed_heavy_estimated,
                'failed_heavy_true': failed_heavy_true,
                'recovery_units': recovery_units,
                'distances_to_working_any': distances_to_working_any,
                'distances_to_working_smaller': distances_to_working_smaller
            })

            print(f"  Failures: {len(failed_results)}, Total heavy: {total_heavy}, Failed heavy: {failed_heavy}")
            if recovery_units:
                print(f"  Recovery units: μ={np.mean(recovery_units):.1f}, med={np.median(recovery_units):.0f}, max={np.max(recovery_units)}")
            if distances_to_working_any:
                print(f"  Distance any: μ={np.mean(distances_to_working_any):.1f}, med={np.median(distances_to_working_any):.0f}, max={np.max(distances_to_working_any)}")
            if distances_to_working_smaller:
                print(f"  Distance smaller: μ={np.mean(distances_to_working_smaller):.1f}, med={np.median(distances_to_working_smaller):.0f}, max={np.max(distances_to_working_smaller)}")
            else:
                print(f"  Distance smaller: No smaller working budgets found")

        return trial_data

    def print_method_summary(self, method_name, trial_data, n_groups, epsilon_val):
        """Print summary statistics for a method.

        Failure rates are relative to the n_groups - 1 budgets evaluated per
        trial; they are NaN when fewer than two groups exist. Missing
        recovery/distance data prints as NaN.

        Returns:
            dict with ``avg_failures``, ``failure_rate_pct``, ``avg_recovery``
            and ``n_groups``.
        """
        print(f"\n{'='*100}")
        print(f"SUMMARY - {method_name} - ε={epsilon_val} - {n_groups} GROUPS")
        print("="*100)
        print(f"{'Fail μ':<7} {'Fail σ':<7} {'FailR% μ':<9} {'FailR% σ':<9} {'TotHvy':<8} {'FailHvy':<9} {'Rec μ':<7} {'Rec med':<8} {'Rec max':<8} {'DAny μ':<8} {'DAny σ':<10} {'DAny max':<10} {'DSmall μ':<10} {'DSmall σ':<12} {'DSmall max':<12}")
        print("-"*120)

        all_failures = [t['num_failures'] for t in trial_data]
        all_total_heavy = [t['total_heavy'] for t in trial_data]
        all_failed_heavy = [t['failed_heavy'] for t in trial_data]
        all_recovery = [u for t in trial_data for u in t['recovery_units']]
        all_distances_any = [d for t in trial_data for d in t['distances_to_working_any']]
        all_distances_smaller = [d for t in trial_data for d in t['distances_to_working_smaller']]

        avg_failures = np.mean(all_failures)
        std_failures = np.std(all_failures)

        n_budgets = n_groups - 1  # budgets K = 1 .. n_groups-1 per trial
        if n_budgets > 0:
            avg_failure_rate = avg_failures / n_budgets * 100
            std_failure_rate = std_failures / n_budgets * 100
        else:
            avg_failure_rate = std_failure_rate = np.nan

        avg_total_heavy = np.mean(all_total_heavy)
        avg_failed_heavy = np.mean(all_failed_heavy)

        def _stats(values, spread):
            """(mean, spread(values), max), or NaNs when ``values`` is empty."""
            if not values:
                return np.nan, np.nan, np.nan
            return np.mean(values), spread(values), np.max(values)

        recovery_mean, recovery_med, recovery_max = _stats(all_recovery, np.median)
        distance_any_mean, distance_any_std, distance_any_max = _stats(all_distances_any, np.std)
        distance_smaller_mean, distance_smaller_std, distance_smaller_max = _stats(all_distances_smaller, np.std)

        print(f"{avg_failures:<7.1f} {std_failures:<7.1f} {avg_failure_rate:<9.1f} {std_failure_rate:<9.1f} {avg_total_heavy:<8.1f} {avg_failed_heavy:<9.1f} "
              f"{recovery_mean:<7.1f} {recovery_med:<8.0f} {recovery_max:<8.0f} "
              f"{distance_any_mean:<8.1f} {distance_any_std:<10.1f} {distance_any_max:<10.0f} "
              f"{distance_smaller_mean:<10.1f} {distance_smaller_std:<12.1f} {distance_smaller_max:<12.0f}")

        return {
            'avg_failures': avg_failures,
            'failure_rate_pct': avg_failure_rate,
            'avg_recovery': recovery_mean,
            'n_groups': n_groups
        }


def _summary_row(eps_result):
    """Flatten one per-epsilon result record into the dict used by the summary tables.

    ``eps_result`` is a record built by ``run_comprehensive_acupuncture_analysis``
    (keys 'method', 'epsilon', 'sqrt_epsilon', 'gamma', 'rho', 'stats'); its
    'stats' entry is the dict returned by ``print_method_summary``.
    """
    stats = eps_result['stats']
    return {
        'method': eps_result['method'],
        'epsilon': eps_result['epsilon'],
        'sqrt_eps': eps_result['sqrt_epsilon'],
        'gamma': eps_result['gamma'],
        'rho': eps_result['rho'],
        'avg_failures': stats['avg_failures'],
        'failure_rate_pct': stats['failure_rate_pct'],
        'avg_recovery': stats['avg_recovery'],
        'n_groups': stats['n_groups'],
    }


def _print_summary_table(rows, method_width=None):
    """Print summary rows as a fixed-width ε/√ε/γ/ρ/Groups/Fail μ/FailR%/Rec μ table.

    When ``method_width`` is given, a leading left-aligned Method column of that
    width is added; otherwise the method is omitted.
    """
    header = (f"{'ε':<8} {'√ε':<10} {'γ':<6} {'ρ':<10} {'Groups':<8} "
              f"{'Fail μ':<8} {'FailR%':<8} {'Rec μ':<8}")
    if method_width is None:
        print(header)
        print("-" * 80)
    else:
        print(f"{'Method':<{method_width}} {header}")
        print("-" * (method_width + 82))
    for row in rows:
        line = (f"{row['epsilon']:<8} {row['sqrt_eps']:<10.6f} {row['gamma']:<6} {row['rho']:<10.6f} "
                f"{row['n_groups']:<8} {row['avg_failures']:<8.1f} {row['failure_rate_pct']:<8.1f} "
                f"{row['avg_recovery']:<8.1f}")
        if method_width is not None:
            line = f"{row['method']:<{method_width}} {line}"
        print(line)


def _mean_failure_rate_by(rows, key):
    """Return {row[key]: mean of 'failure_rate_pct'} over ``rows``, in first-seen key order."""
    grouped = {}
    for row in rows:
        grouped.setdefault(row[key], []).append(row['failure_rate_pct'])
    return {k: np.mean(v) for k, v in grouped.items()}


def run_comprehensive_acupuncture_analysis(df_acupuncture, epsilon_values=None, n_trials=30, log_file=None,
                                           gamma=0.5, heavy_multiplier=1.6):
    """Run every acupuncture grouping method over a grid of epsilon values and log the results.

    For each (method, ε) pair a fresh ``AcupunctureCATEAllocator`` is built. The
    data are processed, groups are created and normalized, the CATE distribution
    is plotted, ``n_trials`` trials are analyzed, and a per-configuration summary
    is printed. Afterwards per-method, combined and ranking summaries are printed.
    All console output is mirrored to ``log_file`` through ``TeeOutput``.
    ``sys.stdout`` is restored and the log closed even if an error escapes.

    Args:
        df_acupuncture: Raw acupuncture DataFrame, passed to ``process_acupuncture_data``.
        epsilon_values: ε values to sweep. Defaults to
            [0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2].
        n_trials: Trials per (method, ε), forwarded to ``analyze_method``.
        log_file: Log file path. Defaults to a timestamped name tagged with γ.
        gamma: γ used for every allocator (ρ = γ√ε). Defaults to 0.5.
        heavy_multiplier: Heavy-interval threshold multiplier passed to each
            allocator. Defaults to 1.6.

    Returns:
        ``(all_results, summary_data)``:
        - ``all_results`` maps each method name to a (possibly empty) list of
          per-ε result dicts.
        - ``summary_data`` is a flat list of per-(method, ε) summary rows.

    Errors raised while building groups or analyzing a single (method, ε) pair
    are printed with their traceback, and that pair is skipped. Errors from
    allocator construction or data processing propagate to the caller.
    """
    import traceback  # only needed to report per-configuration failures

    if epsilon_values is None:
        epsilon_values = [0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2]

    if log_file is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        gamma_tag = str(gamma).replace('.', '')  # 0.5 -> "05", as in the historical file names
        log_file = f"acupuncture_comprehensive_analysis_gamma{gamma_tag}_{timestamp}.txt"

    original_stdout = sys.stdout
    tee = TeeOutput(log_file)
    sys.stdout = tee

    try:
        print(f"COMPREHENSIVE ACUPUNCTURE ANALYSIS - ALL METHODS, FIXED γ={gamma}, HEAVY THRESHOLD={heavy_multiplier}x")
        print(f"Log file: {log_file}")
        print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print("="*100)

        methods = [
            ('Age-Chronicity Interaction', lambda allocator, df: allocator.create_age_chronicity_groups(df, n_groups=30, min_size=6)),
            ('Baseline-Age Interaction', lambda allocator, df: allocator.create_baseline_age_groups(df, n_groups=30, min_size=6)),
            ('Age Groups', lambda allocator, df: allocator.create_age_groups(df, n_groups=30, min_size=6)),
            ('Chronicity Groups', lambda allocator, df: allocator.create_chronicity_groups(df, n_groups=30, min_size=6)),
            ('Multidimensional Composite', lambda allocator, df: allocator.create_multidimensional_groups(df, n_groups=30, min_size=6)),
            ('Baseline Headache', lambda allocator, df: allocator.create_baseline_headache_groups(df, n_groups=30, min_size=6)),
            ('Covariate Forest 30', lambda allocator, df: allocator.create_covariate_forest_groups(df, n_groups=30, min_size=6)),
            ('Covariate Forest 50', lambda allocator, df: allocator.create_covariate_forest_groups(df, n_groups=50, min_size=6)),
            ('Propensity Score', lambda allocator, df: allocator.create_propensity_groups(df, n_groups=50, min_size=6))
        ]

        all_results = {}

        for method_name, method_func in methods:
            print(f"\n{'='*120}")
            print(f"ANALYZING ACUPUNCTURE METHOD: {method_name}")
            print("="*120)

            method_results = []

            for eps in epsilon_values:
                print(f"\n{'='*100}")
                print(f"METHOD: {method_name} | EPSILON = {eps}")
                print("="*100)

                allocator = AcupunctureCATEAllocator(epsilon=eps, gamma=gamma, heavy_multiplier=heavy_multiplier)
                df_processed = allocator.process_acupuncture_data(df_acupuncture)

                try:
                    groups = method_func(allocator, df_processed)

                    if len(groups) < 3:
                        print(f"Too few groups ({len(groups)}) for {method_name} with ε = {eps} - skipping")
                        continue

                    groups = allocator.normalize_cates(groups)

                    allocator.plot_cate_distribution(groups, f" ({method_name}, ε={eps})")

                    trial_data = allocator.analyze_method(groups, eps, n_trials)

                    stats = allocator.print_method_summary(method_name, trial_data, len(groups), eps)
                except Exception as e:
                    # Best effort: skip this configuration, but keep the traceback in the log.
                    print(f"Error with {method_name} at ε = {eps}: {e}")
                    print(traceback.format_exc(), end='')
                    continue

                method_results.append({
                    'method': method_name,
                    'epsilon': eps,
                    'sqrt_epsilon': np.sqrt(eps),
                    'gamma': allocator.gamma,
                    'rho': allocator.rho,  # γ√ε as computed by the allocator itself
                    'groups': groups,
                    'trial_data': trial_data,
                    'stats': stats,
                })

            all_results[method_name] = method_results

            if method_results:
                print(f"\n{'='*120}")
                print(f"METHOD SUMMARY - {method_name} - ALL EPSILON VALUES")
                print("="*120)
                _print_summary_table([_summary_row(r) for r in method_results])
                print("="*120)

        print(f"\n{'='*200}")
        print("COMPREHENSIVE SUMMARY - ALL ACUPUNCTURE METHODS AND EPSILON VALUES")
        print("="*200)

        summary_data = []

        for method_name, method_results in all_results.items():
            if not method_results:
                continue

            print(f"\n{'-'*100}")
            print(f"ACUPUNCTURE METHOD: {method_name}")
            print("-"*100)

            method_rows = [_summary_row(r) for r in method_results]
            summary_data.extend(method_rows)
            _print_summary_table(method_rows)

        # Size the Method column to the longest name so the columns stay aligned.
        method_width = max([18] + [len(row['method']) for row in summary_data])

        print(f"\n{'='*200}")
        print("OVERALL SUMMARY TABLE - ALL ACUPUNCTURE METHODS COMBINED")
        print("="*200)
        _print_summary_table(summary_data, method_width=method_width)

        print(f"\n{'='*100}")
        print("KEY INSIGHTS FOR ACUPUNCTURE DATASET")
        print("="*100)

        # A NaN mean would make min/max/sorted order arbitrary, so only finite means are ranked.
        method_performance = {
            method: rate
            for method, rate in _mean_failure_rate_by(summary_data, 'method').items()
            if np.isfinite(rate)
        }

        if method_performance:
            best_method = min(method_performance, key=method_performance.get)
            worst_method = max(method_performance, key=method_performance.get)

            print(f"BEST PERFORMING ACUPUNCTURE METHOD: {best_method}")
            print(f"  Average failure rate: {method_performance[best_method]:.1f}%")

            print(f"\nWORST PERFORMING ACUPUNCTURE METHOD: {worst_method}")
            print(f"  Average failure rate: {method_performance[worst_method]:.1f}%")

            print("\nACUPUNCTURE METHOD RANKING (by average failure rate):")
            sorted_methods = sorted(method_performance.items(), key=lambda x: x[1])
            for i, (method, rate) in enumerate(sorted_methods, 1):
                print(f"  {i}. {method}: {rate:.1f}%")

        print("\nEFFECT OF EPSILON ON ACUPUNCTURE DATA:")
        epsilon_performance = _mean_failure_rate_by(summary_data, 'epsilon')

        if epsilon_performance:
            rho_label = f"ρ = {gamma}√ε"
            print(f"{'Epsilon':<10} {'Avg Failure Rate':<16} {rho_label:<12}")
            print("-" * 40)
            for eps in sorted(epsilon_performance):
                rho = gamma * np.sqrt(eps)
                print(f"{eps:<10} {epsilon_performance[eps]:<16.1f} {rho:<12.6f}")

        return all_results, summary_data

    finally:
        sys.stdout = original_stdout
        tee.close()


# Script entry point: run the full method x epsilon sweep on the acupuncture trial data.
if __name__ == "__main__":
    # Epsilon grid shared with the NSW and medical analyses.
    epsilon_values = [0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2]

    # Stata file, resolved relative to the current working directory.
    df_acupuncture = pd.read_stata('acupuncture.dta')

    results, summary = run_comprehensive_acupuncture_analysis(
        df_acupuncture,
        n_trials=30,
        epsilon_values=epsilon_values,
        log_file="acupuncture_comprehensive_analysis_gamma05_heavy16.txt",
    )